{
 "cells": [
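  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "REINFORCE estimates Q(state, action) by Monte Carlo sampling.\n",
    "\n",
    "Actor-Critic instead uses a neural network (the critic) for this estimate. In this notebook the critic takes only the state as input, so its output is a state value that stands in for Q(state, action).\n",
    "\n",
    "The critic network is then trained on the temporal-difference (TD) error."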
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADMCAYAAADTcn7NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAATrklEQVR4nO3dfWxT970G8MfOiyEvxyEJsZebWCAVlUW8bAuQnHVXq4pH1kXTWDNpm1CXVYiqmYNKM3G1SC0VqFdBTLpdu9FwpWrQfzq6TGJTI1oWhTbcCUNKWKYQIGolqkQlxy5wc+ykxEns7/2jN2d1CSxOgn82fj7SkTjn97X9PSfxw/Hv2LFNRARERArYVTdARJmLAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMooC6DDhw9j1apVWLZsGWpqatDb26uqFSJSREkAvfXWW2hpacGLL76IixcvYuPGjairq0MwGFTRDhEpYlPxYdSamhps3rwZv/vd7wAAsVgMlZWV2L17N371q18lux0iUiQ72Q84NTWFvr4+tLa2Wtvsdju8Xi/8fv+ct4lEIohEItZ6LBbDrVu3UFJSApvNdt97JqLEiAjC4TDKy8tht9/9hVbSA+jGjRuIRqNwuVxx210uF65evTrnbdra2rB///5ktEdES2hkZAQVFRV3HU96AC1Ea2srWlparHXTNOHxeDAyMgJN0xR2RkRzCYVCqKysRGFh4T3rkh5ApaWlyMrKQiAQiNseCATgdrvnvI3D4YDD4bhju6ZpDCCiFPavpkiSfhUsNzcX1dXV6O7utrbFYjF0d3dD1/Vkt0NECil5CdbS0oLGxkZs2rQJW7ZswW9+8xtMTEzgqaeeUtEOESmiJIB+/OMf49NPP8W+fftgGAa+9rWv4d13371jYpqIHmxK3ge0WKFQCE6nE6Zpcg6IKAXN9znKz4IRkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhImYQD6MyZM/j+97+P8vJy2Gw2/PnPf44bFxHs27cPX/nKV7B8+XJ4vV58+OGHcTW3bt3Cjh07oGkaioqKsHPnToyPjy9qR4go/SQcQBMTE9i4cSMOHz485/ihQ4fw6quv4siRIzh//jzy8/NRV1eHyclJq2bHjh0YHBxEV1cXOjs7cebMGTz99NML3wsiSk+yCADkxIkT1nosFhO32y2//vWvrW1jY2PicDjkD3/4g4iIXL58WQDIBx98YNW88847YrPZ5JNPPpnX45qmKQDENM3FtE9E98l8n6NLOgd07do1GIYBr9drbXM6naipqYHf7wcA+P1+FBUVYdOmTVaN1+uF3W7H+fPn57zfSCSCUCgUtxBR+lvSADIMAwDgcrnitrtcLmvMMAyUlZXFjWdnZ6O4uNiq+bK2tjY4nU5rqaysXMq2iUiRtLgK1traCtM0rWVkZER1S0S0BJY0gNxuNwAgEAjEbQ8EAtaY2+1GMBiMG5+ZmcGtW7esmi9zOBzQNC1uIaL0t6QBtHr1arjdbnR3d1vbQqEQzp8/D13XAQC6rmNsbAx9fX1WzenTpxGLxVBTU7OU7RBRistO9Abj4+P46KOPrPVr166hv78fxcXF8Hg82LNnD1566SWsWbMGq1evxgsvvIDy8nJs374dAPDVr34V3/3ud7Fr1y4cOXIE09PTaG5uxk9+8hOUl5cv2Y4RURpI9PLae++9JwDuWBobG0Xk80vxL7zwgrhcLnE4HLJ161YZGhqKu4+bN2/KT3/6UykoKBBN0+Spp56ScDi85Jf4iEiN+T5HbSIiCvNvQUKhEJxOJ0zT5HwQUQqa73M0La6CEdGD
iQFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKZPw1/IQLbXw6IeYGr9lrTucZchfuQo2m01hV5QMDCBSLjDQjf+9dtFaL13778hfuUpdQ5Q0fAlGKSimugFKEgYQpZ60+6Y6WigGEKWcNPyuTFogBhCloNlv/KYHHQOIUg/PgDIGA4hSjojwBChDJBRAbW1t2Lx5MwoLC1FWVobt27djaGgormZychI+nw8lJSUoKChAQ0MDAoFAXM3w8DDq6+uRl5eHsrIy7N27FzMzM4vfG3pAMH0yRUIB1NPTA5/Ph3PnzqGrqwvT09PYtm0bJiYmrJrnnnsOb7/9Njo6OtDT04Pr16/jiSeesMaj0Sjq6+sxNTWFs2fP4o033sCxY8ewb9++pdsrSm/COaCMIYsQDAYFgPT09IiIyNjYmOTk5EhHR4dVc+XKFQEgfr9fREROnjwpdrtdDMOwatrb20XTNIlEIvN6XNM0BYCYprmY9ilFfHiqXXqP7LKWD//63xKLzqhuixZhvs/RRc0BmaYJACguLgYA9PX1YXp6Gl6v16pZu3YtPB4P/H4/AMDv92P9+vVwuVxWTV1dHUKhEAYHB+d8nEgkglAoFLfQg4xnP5liwQEUi8WwZ88ePPLII1i3bh0AwDAM5ObmoqioKK7W5XLBMAyr5ovhMzs+OzaXtrY2OJ1Oa6msrFxo25QORBhBGWLBAeTz+XDp0iUcP358KfuZU2trK0zTtJaRkZH7/pikjojwUnyGWNCHUZubm9HZ2YkzZ86goqLC2u52uzE1NYWxsbG4s6BAIAC3223V9Pb2xt3f7FWy2ZovczgccDgcC2mV0hLDJ1MkdAYkImhubsaJEydw+vRprF69Om68uroaOTk56O7utrYNDQ1heHgYuq4DAHRdx8DAAILBoFXT1dUFTdNQVVW1mH2hBwXPfjJGQmdAPp8Pb775Jv7yl7+gsLDQmrNxOp1Yvnw5nE4ndu7ciZaWFhQXF0PTNOzevRu6rqO2thYAsG3bNlRVVeHJJ5/EoUOHYBgGnn/+efh8Pp7lEAB+FiyTJBRA7e3tAIBHH300bvvRo0fx85//HADw8ssvw263o6GhAZFIBHV1dXjttdes2qysLHR2dqKpqQm6riM/Px+NjY04cODA4vaEHiB8H1CmsEka/ncTCoXgdDphmiY0TVPdDi3SR389EvcHyZyV6/BQXRPsWTkKu6LFmO9zlJ8Fo5Qj/CxYxmAAUQpi+mQKBhClHn4WLGMwgCjlpOG0JC0QA4hSEAMoUzCASL0vff+XxGI8C8oQDCBSbvmK8rj1SPhTSHRaUTeUTAwgUs6eFf9+WIlFeQaUIRhApJ6Nv4aZij95Uo/fAZ+xGECknI0BlLEYQJQCGECZigFEyvEMKHMxgEg9TkJnLP7kSTmeAWUuBhCpxwDKWAwgUo8BlLEYQKScjb+GGYs/eVKPZ0AZiwFEynESOnMxgEg9BlDGYgCRegygjMUAIuVsfCNixuJPnlIAz4AyVULfjEq0ELFYDOFw+K5/ZOyz27fj1iUmCIVCyIpE56zPzs5Gfn4+J68fAAwguu8Mw8DWrVsRDofnHK9d68J//KjaWr958yZ2fOtbCN+e+8+y6rqOP/7xj/elV0ouBhDdd9FoFKOjozBNc87xG6U5uDH9b/j49nrk2idRMvM/n9dPROauv3HjfrZLSZTQHFB7ezs2bNgATdOgaRp0Xcc777xjjU9OTsLn86GkpAQFBQVoaGhAIBCIu4/h4WHU19cjLy8PZWVl2Lt3L2ZmZpZmbygtjU0X4x/hx3Bj2oPrkTX4R/gxRIXfC58JEgqgiooKHDx4EH19fbhw4QIee+wx/OAHP8Dg4CAA4LnnnsPbb7+Njo4O9PT04Pr163jiiSes20ejUdTX12Nqagpnz57FG2+8gWPH
jmHfvn1Lu1eUViajyzEVW/b/azZ8FtMQQ5bSnihJZJFWrFghr7/+uoyNjUlOTo50dHRYY1euXBEA4vf7RUTk5MmTYrfbxTAMq6a9vV00TZNIJDLvxzRNUwCIaZqLbZ+SYHh4WJxO5+z3Ld+xVFc9LK8c7JT9L/XKgZfOyX/tOyxaft5d6x999FGJxWKqd4vuYb7P0QXPAUWjUXR0dGBiYgK6rqOvrw/T09Pwer1Wzdq1a+HxeOD3+1FbWwu/34/169fD5XJZNXV1dWhqasLg4CC+/vWvJ9TD1atXUVBQsNBdoCQxDAPR6NxXtADACF7H2e7/RHBqFXJsEWj4CJGpued/AGBiYgKXL1/mVbAUNj4+Pq+6hANoYGAAuq5jcnISBQUFOHHiBKqqqtDf34/c3FwUFRXF1btcLhiGAeDzX8Qvhs/s+OzY3UQiEUQi//yFDIVCAADTNDl/lAbudQkeAD65EcZbXX4A/nnd38zMzF0ntCk1TExMzKsu4QB6+OGH0d/fD9M08ac//QmNjY3o6elJuMFEtLW1Yf/+/Xdsr6mpgaZp9/WxafFGRkaQnb10F1ydTid0XecZUAqbPUn4VxJ+J3Rubi4eeughVFdXo62tDRs3bsQrr7wCt9uNqakpjI2NxdUHAgG43W4AgNvtvuOq2Oz6bM1cWltbYZqmtYyMjCTaNhGloEV/FCMWiyESiaC6uho5OTno7u62xoaGhjA8PAxd1wF8/gaygYEBBINBq6arqwuapqGqququj+FwOKxL/7MLEaW/hM6LW1tb8fjjj8Pj8SAcDuPNN9/E+++/j1OnTsHpdGLnzp1oaWlBcXExNE3D7t27oes6amtrAQDbtm1DVVUVnnzySRw6dAiGYeD555+Hz+eDw+G4LztIRKkroQAKBoP42c9+htHRUTidTmzYsAGnTp3Cd77zHQDAyy+/DLvdjoaGBkQiEdTV1eG1116zbp+VlYXOzk40NTVB13Xk5+ejsbERBw4cWNq9opRit9tRWFh4z4noROTn5y/J/ZB6Nlmq34okCoVCcDqdME2TL8fSwMzMDAzDWLIAcjgcWLlyJSehU9h8n6P8LBjdd9nZ2aioqFDdBqUg/j0gIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpEy26gYWQkQAAKFQSHEnRDSX2efm7HP1btIygG7evAkAqKysVNwJEd1LOByG0+m863haBlBxcTEAYHh4+J47R/FCoRAqKysxMjICTdNUt5MWeMwWRkQQDodRXl5+z7q0DCC7/fOpK6fTyV+KBdA0jcctQTxmiZvPyQEnoYlIGQYQESmTlgHkcDjw4osvwuFwqG4lrfC4JY7H7P6yyb+6TkZEdJ+k5RkQET0YGEBEpAwDiIiUYQARkTJpGUCHDx/GqlWrsGzZMtTU1KC3t1d1S8q0tbVh8+bNKCwsRFlZGbZv346hoaG4msnJSfh8PpSUlKCgoAANDQ0IBAJxNcPDw6ivr0deXh7Kysqwd+9ezMzMJHNXlDl48CBsNhv27NljbeMxSxJJM8ePH5fc3Fz5/e9/L4ODg7Jr1y4pKiqSQCCgujUl6urq5OjRo3Lp0iXp7++X733ve+LxeGR8fNyqeeaZZ6SyslK6u7vlwoULUltbK9/85jet8ZmZGVm3bp14vV75+9//LidPnpTS0lJpbW1VsUtJ1dvbK6tWrZINGzbIs88+a23nMUuOtAugLVu2iM/ns9aj0aiUl5dLW1ubwq5SRzAYFADS09MjIiJjY2OSk5MjHR0dVs2VK1cEgPj9fhEROXnypNjtdjEMw6ppb28XTdMkEokkdweSKBwOy5o1a6Srq0u+/e1vWwHEY5Y8afUS
bGpqCn19ffB6vdY2u90Or9cLv9+vsLPUYZomgH9+YLevrw/T09Nxx2zt2rXweDzWMfP7/Vi/fj1cLpdVU1dXh1AohMHBwSR2n1w+nw/19fVxxwbgMUumtPow6o0bNxCNRuN+6ADgcrlw9epVRV2ljlgshj179uCRRx7BunXrAACGYSA3NxdFRUVxtS6XC4ZhWDVzHdPZsQfR8ePHcfHiRXzwwQd3jPGYJU9aBRDdm8/nw6VLl/C3v/1NdSspbWRkBM8++yy6urqwbNky1e1ktLR6CVZaWoqsrKw7rkYEAgG43W5FXaWG5uZmdHZ24r333kNFRYW13e12Y2pqCmNjY3H1Xzxmbrd7zmM6O/ag6evrQzAYxDe+8Q1kZ2cjOzsbPT09ePXVV5GdnQ2Xy8VjliRpFUC5ubmorq5Gd3e3tS0Wi6G7uxu6rivsTB0RQXNzM06cOIHTp09j9erVcePV1dXIycmJO2ZDQ0MYHh62jpmu6xgYGEAwGLRqurq6oGkaqqqqkrMjSbR161YMDAygv7/fWjZt2oQdO3ZY/+YxSxLVs+CJOn78uDgcDjl27JhcvnxZnn76aSkqKoq7GpFJmpqaxOl0yvvvvy+jo6PW8tlnn1k1zzzzjHg8Hjl9+rRcuHBBdF0XXdet8dlLytu2bZP+/n559913ZeXKlRl1SfmLV8FEeMySJe0CSETkt7/9rXg8HsnNzZUtW7bIuXPnVLekDIA5l6NHj1o1t2/fll/84heyYsUKycvLkx/+8IcyOjoadz8ff/yxPP7447J8+XIpLS2VX/7ylzI9PZ3kvVHnywHEY5Yc/HMcRKRMWs0BEdGDhQFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEp83/IWZg6uWdpDwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Define the environment\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('CartPole-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        over = terminated or truncated\n",
    "\n",
    "        # Cap the episode length at 200 steps\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            over = True\n",
    "        \n",
    "        # Penalize episodes that end before reaching the step cap\n",
    "        if over and self.step_n < 200:\n",
    "            reward = -1000\n",
    "\n",
    "        return state, reward, over\n",
    "\n",
    "    # Show the current game frame\n",
    "    def show(self):\n",
    "        from matplotlib import pyplot as plt\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(self.env.render())\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()\n",
    "\n",
    "env.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Sequential(\n",
       "   (0): Linear(in_features=4, out_features=64, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=64, out_features=64, bias=True)\n",
       "   (3): ReLU()\n",
       "   (4): Linear(in_features=64, out_features=2, bias=True)\n",
       "   (5): Softmax(dim=1)\n",
       " ),\n",
       " Sequential(\n",
       "   (0): Linear(in_features=4, out_features=64, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=64, out_features=64, bias=True)\n",
       "   (3): ReLU()\n",
       "   (4): Linear(in_features=64, out_features=1, bias=True)\n",
       " ))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "# Actor model: outputs the probability of each action\n",
    "model_actor = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 2),\n",
    "    torch.nn.Softmax(dim=1),\n",
    ")\n",
    "\n",
    "# Critic model: estimates the value of each state\n",
    "model_critic = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "# Delayed copy of the critic, used to compute the TD targets\n",
    "model_critic_delay = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "model_critic_delay.load_state_dict(model_critic.state_dict())\n",
    "\n",
    "model_actor, model_critic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-935.0"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Play one episode and record the transitions\n",
    "def play(show=False):\n",
    "    state = []\n",
    "    action = []\n",
    "    reward = []\n",
    "    next_state = []\n",
    "    over = []\n",
    "\n",
    "    s = env.reset()\n",
    "    o = False\n",
    "    while not o:\n",
    "        # Sample an action according to the actor's probabilities\n",
    "        prob = model_actor(torch.FloatTensor(s).reshape(1, 4))[0].tolist()\n",
    "        a = random.choices(range(2), weights=prob, k=1)[0]\n",
    "\n",
    "        ns, r, o = env.step(a)\n",
    "\n",
    "        state.append(s)\n",
    "        action.append(a)\n",
    "        reward.append(r)\n",
    "        next_state.append(ns)\n",
    "        over.append(o)\n",
    "\n",
    "        s = ns\n",
    "\n",
    "        if show:\n",
    "            display.clear_output(wait=True)\n",
    "            env.show()\n",
    "\n",
    "    # Stack each list of ndarrays into one ndarray first;\n",
    "    # building a tensor directly from a list of ndarrays is slow\n",
    "    state = torch.FloatTensor(np.array(state)).reshape(-1, 4)\n",
    "    action = torch.LongTensor(action).reshape(-1, 1)\n",
    "    reward = torch.FloatTensor(reward).reshape(-1, 1)\n",
    "    next_state = torch.FloatTensor(np.array(next_state)).reshape(-1, 4)\n",
    "    over = torch.LongTensor(over).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over, reward.sum().item()\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over, reward_sum = play()\n",
    "\n",
    "reward_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer_actor = torch.optim.Adam(model_actor.parameters(), lr=4e-3)\n",
    "optimizer_critic = torch.optim.Adam(model_critic.parameters(), lr=4e-2)\n",
    "\n",
    "\n",
    "def requires_grad(model, value):\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad_(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([66, 1])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_critic(state, reward, next_state, over):\n",
    "    requires_grad(model_actor, False)\n",
    "    requires_grad(model_critic, True)\n",
    "\n",
    "    # Compute the values and the TD targets\n",
    "    value = model_critic(state)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        target = model_critic_delay(next_state)\n",
    "    target = target * 0.99 * (1 - over) + reward\n",
    "\n",
    "    # TD error, i.e. the TD loss (MSE between value and target)\n",
    "    loss = torch.nn.functional.mse_loss(value, target)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer_critic.step()\n",
    "    optimizer_critic.zero_grad()\n",
    "\n",
    "    return value.detach()\n",
    "\n",
    "\n",
    "value = train_critic(state, reward, next_state, over)\n",
    "\n",
    "value.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.008244477212429047"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_actor(state, action, value):\n",
    "    requires_grad(model_actor, True)\n",
    "    requires_grad(model_critic, False)\n",
    "\n",
    "    # Recompute the probabilities of the actions that were taken\n",
    "    prob = model_actor(state)\n",
    "    prob = prob.gather(dim=1, index=action)\n",
    "\n",
    "    # Implements the policy-gradient objective.\n",
    "    # The Q(state, action) term is approximated by the critic's estimate.\n",
    "    prob = (prob + 1e-8).log() * value\n",
    "    loss = -prob.mean()\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer_actor.step()\n",
    "    optimizer_actor.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_actor(state, action, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -7.496856212615967 -980.05\n",
      "100 -296.3982238769531 -937.9\n",
      "200 -245.56146240234375 -57.95\n",
      "300 -208.5748748779297 -884.4\n",
      "400 -59.092411041259766 200.0\n",
      "500 -46.43401336669922 200.0\n",
      "600 -8.897034645080566 200.0\n",
      "700 -7.917058944702148 200.0\n",
      "800 -1.8229550123214722 200.0\n",
      "900 7.047328948974609 200.0\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    model_actor.train()\n",
    "    model_critic.train()\n",
    "\n",
    "    # Train for N epochs\n",
    "    for epoch in range(1000):\n",
    "\n",
    "        # Play at least N steps per epoch\n",
    "        steps = 0\n",
    "        while steps < 200:\n",
    "            state, action, reward, next_state, over, _ = play()\n",
    "            steps += len(state)\n",
    "\n",
    "            # Train both models\n",
    "            value = train_critic(state, reward, next_state, over)\n",
    "            loss = train_actor(state, action, value)\n",
    "\n",
    "        # Soft-update the delayed critic: 70% old weights, 30% current critic\n",
    "        for param, param_delay in zip(model_critic.parameters(),\n",
    "                                      model_critic_delay.parameters()):\n",
    "            value = param_delay.data * 0.7 + param.data * 0.3\n",
    "            param_delay.data.copy_(value)\n",
    "\n",
    "        if epoch % 100 == 0:\n",
    "            test_result = sum([play()[-1] for _ in range(20)]) / 20\n",
    "            print(epoch, loss, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADMCAYAAADTcn7NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUpUlEQVR4nO3dfWxT570H8O9xEjvk5TgEiN1c4hFpvaMZL10DJF7v1aaRJWPR7mjzxzahLqsQXamDSjMhLVeFClQpiEm3HRsN0t0d9B/GlF2xqhltb5bQoN0aAmG5NwSIOrW7SQE7Dcx2CInt+PzuH1XOapJAQoKf2Hw/0pE4z/PY/p0D58t5tTURERARKWBRXQARPbwYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpIyyADp06BBWrFiBzMxMlJWVobOzU1UpRKSIkgD67W9/i/r6erzyyiu4cOEC1q5di6qqKgwODqooh4gU0VQ8jFpWVob169fjl7/8JQDAMAwUFRVhx44d+OlPf5rocohIkfREf2AkEkFXVxcaGhrMNovFgoqKCni93ilfEw6HEQ6HzXnDMHDz5k0sWbIEmqY98JqJaHZEBMPDwygsLITFMv2BVsIDaGhoCLFYDA6HI67d4XDgypUrU76msbERe/fuTUR5RDSPBgYGsHz58mn7Ex5A96OhoQH19fXmfDAYhMvlwsDAAHRdV1gZEU0lFAqhqKgIubm5dx2X8ABaunQp0tLS4Pf749r9fj+cTueUr7HZbLDZbJPadV1nABEtYPc6RZLwq2BWqxWlpaVoa2sz2wzDQFtbG9xud6LLISKFlByC1dfXo7a2FuvWrcOGDRvw+uuvY2RkBM8++6yKcohIESUB9L3vfQ+ffvop9uzZA5/Ph8cffxzvvvvupBPTRJTalNwHNFehUAh2ux3BYJDngIgWoJluo3wWjIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMrMOoBOnz6N73znOygsLISmafj9738f1y8i2LNnDx555BEsWrQIFRUV+PDDD+PG3Lx5E1u2bIGu68jLy8PWrVtx69atOS0IESWfWQfQyMgI1q5di0OHDk3Zf+DAARw8eBCHDx/G2bNnkZ2djaqqKoyNjZljtmzZgt7eXrS2tqKlpQWnT5/Gc889d/9LQUTJSeYAgJw4ccKcNwxDnE6n/OxnPzPbAoGA2Gw2+c1vfiMiIpcuXRIAcu7cOXPMO++8I5qmydWrV2f0ucFgUABIMBicS/lE9IDMdBud13NAH3/8MXw+HyoqKsw2u92OsrIyeL1eAIDX60VeXh7WrVtnjqmoqIDFYsHZs2enfN9wOIxQKBQ3EVHym9cA8vl8AACHwxHX7nA4zD6fz4eCgoK4/vT0dOTn55tj7tTY2Ai73W5ORUVF81k2ESmSFFfBGhoaEAwGzWlgYEB1SUQ0D+Y1gJxOJwDA7/fHtfv9frPP6XRicHAwrn98fBw3b940x9zJZrNB1/W4iYiS37wGUHFxMZxOJ9ra2sy2UCiEs2fPwu12AwDcbjcCgQC6urrMMe3t7TAMA2VlZfNZDhEtcOmzfcGtW7fwl7/8xZz/+OOP0d3djfz8fLhcLuzcuROvvvoqHn30URQXF2P37t0oLCzE5s2bAQCPPfYYvvWtb2Hbtm04fPgwotEo6urq8P3vfx+FhYXztmBElARme3nt1KlTAmDSVFtbKyKfXYrfvXu3OBwOsdlssnHjRunr64t7jxs3bsgPfvADycnJEV3X5dlnn5Xh4eF5v8RHRGrMdBvVREQU5t99CYVCsNvtCAaDPB9EtADNdBtNiqtgRJSa
GEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMrP+WR4iWlgiIwEMX+sz57W0DOR9YTUsaRkKq5oZBhBRkrt9YwAftf+HOZ+emYscx25Ys/PUFTVDPAQjSjEiBozxiOoyZoQBRJRqxIAxHlVdxYwwgIiSnGZJAzTNnBcjhlh0VGFFM8cAIkpytpwlSLdlm/PGeATh0KcKK5o5BhBRktPSMz7bC/q8JPnFdQYQUZKzpGVA05JzU55V1Y2NjVi/fj1yc3NRUFCAzZs3o6+vL27M2NgYPB4PlixZgpycHNTU1MDv98eN6e/vR3V1NbKyslBQUIBdu3ZhfHx87ktD9BCypE2xBwRAkmAvaFYB1NHRAY/HgzNnzqC1tRXRaBSVlZUYGRkxx7z00kt4++230dzcjI6ODly7dg1PP/202R+LxVBdXY1IJIIPPvgAb775Jo4ePYo9e/bM31IRPUw0Le4kNAAYsST5D13mYHBwUABIR0eHiIgEAgHJyMiQ5uZmc8zly5cFgHi9XhEROXnypFgsFvH5fOaYpqYm0XVdwuHwjD43GAwKAAkGg3MpnygljEfG5H+O/at0Ht5mTte7/0sMw1BW00y30TkdOAaDQQBAfn4+AKCrqwvRaBQVFRXmmJUrV8LlcsHr9QIAvF4vVq9eDYfDYY6pqqpCKBRCb2/vlJ8TDocRCoXiJiKaXiw6prqEGbnvADIMAzt37sSTTz6JVatWAQB8Ph+sVivy8vLixjocDvh8PnPM58Nnon+ibyqNjY2w2+3mVFRUdL9lE6UeTZt0Ejrl7wPyeDy4ePEijh8/Pp/1TKmhoQHBYNCcBgYGHvhnEiULzZKGzPzCuLbRG1cVVTM79/Uwal1dHVpaWnD69GksX77cbHc6nYhEIggEAnF7QX6/H06n0xzT2dkZ934TV8kmxtzJZrPBZrPdT6lEKU+DhnRrVlybGMlxEnpWe0Aigrq6Opw4cQLt7e0oLi6O6y8tLUVGRgba2trMtr6+PvT398PtdgMA3G43enp6MDg4aI5pbW2FrusoKSmZy7IQPZw0wJJuVV3FfZnVHpDH48GxY8fw1ltvITc31zxnY7fbsWjRItjtdmzduhX19fXIz8+HruvYsWMH3G43ysvLAQCVlZUoKSnBM888gwMHDsDn8+Hll1+Gx+PhXg7RfdFgyYjfdhb+HUCfmVUANTU1AQC+/vWvx7UfOXIEP/rRjwAAr732GiwWC2pqahAOh1FVVYU33njDHJuWloaWlhZs374dbrcb2dnZqK2txb59++a2JEQPsUk3IhoGIAagTb5BcSHRRJLgdsk7hEIh2O12BINB6Lquuhwi5T459xauX/iDOZ+17At47F92KTs0m+k2mpwPkBDRXUlsHIYRU13GPTGAiFKAdse8EYtCkuBxDAYQUQrIXPwI8LmbEaMjQYyPjdzlFQsDA4goBaRn5kKL+1bEcYjwEIyIEiAtIzlvYWEAEaUAS7p10ldyJAMGEFEK0NIm39InvApGRKoY0bDqEu6JAUSUgkQEMQYQESWCpmnQ4h67EMTCt5XVM1MMIKIUkG7LhjVn8d8bRDB68xN1Bc0QA4goBWiWdFjSMuLakuExTwYQUQrQLGlTXglb6BhARClAs1gm7QEBC38viAFElAo0CzRL/ObMRzGIKIHu+HHCaAQL/bsRGUBEKcoYjwA8BCOihLjjWbBYdIzngIgoMbLy/yFufvRv1xb8z/MwgIhSRJotO25eYrGFfgqIAUSUKpLxO4EYQEQpwpKRqbqEWWMAEaWIyTciCkQMJbXMFAOIKAVomjbppzHEMCCxqJqCZij5Hh4hekiJCEZGRjA+PvWVrdsj8b+CYcTGEfjbTWREpj4TrWkacnJykJam7tdTGUBESeTHP/4xOjo6puz7py8/gvqnHsfErlAgcBMv
VG/CJ0O3phyflZWFP/7xj3C5XA+o2ntjABElkaGhIVy9enXKvku2KP42Wo5PYuUIG1kosv0vtGj7tOOzs7MRi6l9XmxW54CampqwZs0a6LoOXdfhdrvxzjvvmP1jY2PweDxYsmQJcnJyUFNTA7/fH/ce/f39qK6uRlZWFgoKCrBr165pdymJaOYCIwa6Av+Mq+F/xFC0CBdHNmJUc6ou665mFUDLly/H/v370dXVhfPnz+Mb3/gGvvvd76K3txcA8NJLL+Htt99Gc3MzOjo6cO3aNTz99NPm62OxGKqrqxGJRPDBBx/gzTffxNGjR7Fnz575XSqih9BoRBCK6pg4BItKJsJGltqi7kXmaPHixfKrX/1KAoGAZGRkSHNzs9l3+fJlASBer1dERE6ePCkWi0V8Pp85pqmpSXRdl3A4POPPDAaDAkCCweBcyydKGoZhSGVlpeCz+5snTXm52fJvew7KvlfPyN5XO+X1/X+QstVfnnZ8dna2fPTRRw+k1pluo/d9DigWi6G5uRkjIyNwu93o6upCNBpFRUWFOWblypVwuVzwer0oLy+H1+vF6tWr4XA4zDFVVVXYvn07ent78ZWvfGVWNVy5cgU5OTn3uwhESWdkZPrfe789Oob/7vh3DGunEJFMOKz/h2v+gWnHG4aBDz/8EKOjo/Ne561bU5/4vtOsA6inpwdutxtjY2PIycnBiRMnUFJSgu7ublitVuTl5cWNdzgc8Pl8AACfzxcXPhP9E33TCYfDCIf//hMjoVAIABAMBnn+iB4q0ej09/VExmP4z44eAD0zfr/h4WEEAoG5F3aHuwXl5806gL70pS+hu7sbwWAQv/vd71BbWzvtZcH50tjYiL17905qLysrg67rD/SziRYKEZn0H/xcWCwWPPHEEyguLp6395wwsZNwzxpm+8ZWqxVf/OIXUVpaisbGRqxduxY///nP4XQ6EYlEJqWp3++H0/nZmXin0znpqtjE/MSYqTQ0NCAYDJrTwMD0u5VElDzm/CiGYRgIh8MoLS1FRkYG2trazL6+vj709/fD7XYDANxuN3p6ejA4OGiOaW1tha7rKCkpmfYzbDabeel/YiKi5DerQ7CGhgZs2rQJLpcLw8PDOHbsGN5//3289957sNvt2Lp1K+rr65Gfnw9d17Fjxw643W6Ul5cDACorK1FSUoJnnnkGBw4cgM/nw8svvwyPxwObLfm+SoCI5mZWATQ4OIgf/vCHuH79Oux2O9asWYP33nsP3/zmNwEAr732GiwWC2pqahAOh1FVVYU33njDfH1aWhpaWlqwfft2uN1uZGdno7a2Fvv27ZvfpSJKUdnZ2fN2BJCdnQ2LRe3z6JrIAv/S2CmEQiHY7XYEg0EejtFDQ0QwNDSEsbGxeXk/TdPgdDqRnj7/T2TNdBvls2BESULTNCxbtkx1GfOK3wdERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlElXXcD9EBEAQCgUUlwJEU1lYtuc2Fank5QBdOPGDQBAUVGR4kqI6G6Gh4dht9un7U/KAMrPzwcA9Pf333XhKF4oFEJRUREGBgag67rqcpIC19n9EREMDw+jsLDwruOSMoAsls9OXdntdv6juA+6rnO9zRLX2ezNZOeAJ6GJSBkGEBEpk5QBZLPZ8Morr8Bms6kuJalwvc0e19mDpcm9rpMRET0gSbkHRESpgQFERMowgIhIGQYQESmTlAF06NAhrFixApmZmSgrK0NnZ6fqkpRpbGzE+vXrkZubi4KCAmzevBl9fX1xY8bGxuDxeLBkyRLk5OSgpqYGfr8/bkx/fz+qq6uRlZWFgoIC7Nq1C+Pj44lcFGX2798PTdOw
c+dOs43rLEEkyRw/flysVqv8+te/lt7eXtm2bZvk5eWJ3+9XXZoSVVVVcuTIEbl48aJ0d3fLt7/9bXG5XHLr1i1zzPPPPy9FRUXS1tYm58+fl/LycvnqV79q9o+Pj8uqVaukoqJC/vznP8vJkydl6dKl0tDQoGKREqqzs1NWrFgha9askRdffNFs5zpLjKQLoA0bNojH4zHnY7GYFBYWSmNjo8KqFo7BwUEBIB0dHSIiEggEJCMjQ5qbm80xly9fFgDi9XpFROTkyZNisVjE5/OZY5qamkTXdQmHw4ldgAQaHh6WRx99VFpbW+VrX/uaGUBcZ4mTVIdgkUgEXV1dqKioMNssFgsqKirg9XoVVrZwBINBAH9/YLerqwvRaDRuna1cuRIul8tcZ16vF6tXr4bD4TDHVFVVIRQKobe3N4HVJ5bH40F1dXXcugG4zhIpqR5GHRoaQiwWi/tLBwCHw4ErV64oqmrhMAwDO3fuxJNPPolVq1YBAHw+H6xWK/Ly8uLGOhwO+Hw+c8xU63SiLxUdP34cFy5cwLlz5yb1cZ0lTlIFEN2dx+PBxYsX8ac//Ul1KQvawMAAXnzxRbS2tiIzM1N1OQ+1pDoEW7p0KdLS0iZdjfD7/XA6nYqqWhjq6urQ0tKCU6dOYfny5Wa70+lEJBJBIBCIG//5deZ0OqdcpxN9qaarqwuDg4N44oknkJ6ejvT0dHR0dODgwYNIT0+Hw+HgOkuQpAogq9WK0tJStLW1mW2GYaCtrQ1ut1thZeqICOrq6nDixAm0t7ejuLg4rr+0tBQZGRlx66yvrw/9/f3mOnO73ejp6cHg4KA5prW1Fbquo6SkJDELkkAbN25ET08Puru7zWndunXYsmWL+WeuswRRfRZ8to4fPy42m02OHj0qly5dkueee07y8vLirkY8TLZv3y52u13ef/99uX79ujndvn3bHPP888+Ly+WS9vZ2OX/+vLjdbnG73Wb/xCXlyspK6e7ulnfffVeWLVv2UF1S/vxVMBGus0RJugASEfnFL34hLpdLrFarbNiwQc6cOaO6JGUATDkdOXLEHDM6OiovvPCCLF68WLKysuSpp56S69evx73PX//6V9m0aZMsWrRIli5dKj/5yU8kGo0meGnUuTOAuM4Sg1/HQUTKJNU5ICJKLQwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhImf8HTSjZr6qZNksAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "play(True)[-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第9章-策略梯度算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
